{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:28:17.104938Z",
     "start_time": "2019-11-06T14:28:16.182891Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div class=\"bk-root\">\n",
       "        <a href=\"https://bokeh.pydata.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n",
       "        <span id=\"1001\">Loading BokehJS ...</span>\n",
       "    </div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/javascript": [
       "\n",
       "(function(root) {\n",
       "  function now() {\n",
       "    return new Date();\n",
       "  }\n",
       "\n",
       "  var force = true;\n",
       "\n",
       "  if (typeof (root._bokeh_onload_callbacks) === \"undefined\" || force === true) {\n",
       "    root._bokeh_onload_callbacks = [];\n",
       "    root._bokeh_is_loading = undefined;\n",
       "  }\n",
       "\n",
       "  var JS_MIME_TYPE = 'application/javascript';\n",
       "  var HTML_MIME_TYPE = 'text/html';\n",
       "  var EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n",
       "  var CLASS_NAME = 'output_bokeh rendered_html';\n",
       "\n",
       "  /**\n",
       "   * Render data to the DOM node\n",
       "   */\n",
       "  function render(props, node) {\n",
       "    var script = document.createElement(\"script\");\n",
       "    node.appendChild(script);\n",
       "  }\n",
       "\n",
       "  /**\n",
       "   * Handle when an output is cleared or removed\n",
       "   */\n",
       "  function handleClearOutput(event, handle) {\n",
       "    var cell = handle.cell;\n",
       "\n",
       "    var id = cell.output_area._bokeh_element_id;\n",
       "    var server_id = cell.output_area._bokeh_server_id;\n",
       "    // Clean up Bokeh references\n",
       "    if (id != null && id in Bokeh.index) {\n",
       "      Bokeh.index[id].model.document.clear();\n",
       "      delete Bokeh.index[id];\n",
       "    }\n",
       "\n",
       "    if (server_id !== undefined) {\n",
       "      // Clean up Bokeh references\n",
       "      var cmd = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n",
       "      cell.notebook.kernel.execute(cmd, {\n",
       "        iopub: {\n",
       "          output: function(msg) {\n",
       "            var id = msg.content.text.trim();\n",
       "            if (id in Bokeh.index) {\n",
       "              Bokeh.index[id].model.document.clear();\n",
       "              delete Bokeh.index[id];\n",
       "            }\n",
       "          }\n",
       "        }\n",
       "      });\n",
       "      // Destroy server and session\n",
       "      var cmd = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n",
       "      cell.notebook.kernel.execute(cmd);\n",
       "    }\n",
       "  }\n",
       "\n",
       "  /**\n",
       "   * Handle when a new output is added\n",
       "   */\n",
       "  function handleAddOutput(event, handle) {\n",
       "    var output_area = handle.output_area;\n",
       "    var output = handle.output;\n",
       "\n",
       "    // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n",
       "    if ((output.output_type != \"display_data\") || (!output.data.hasOwnProperty(EXEC_MIME_TYPE))) {\n",
       "      return\n",
       "    }\n",
       "\n",
       "    var toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n",
       "\n",
       "    if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n",
       "      toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n",
       "      // store reference to embed id on output_area\n",
       "      output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n",
       "    }\n",
       "    if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n",
       "      var bk_div = document.createElement(\"div\");\n",
       "      bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n",
       "      var script_attrs = bk_div.children[0].attributes;\n",
       "      for (var i = 0; i < script_attrs.length; i++) {\n",
       "        toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n",
       "      }\n",
       "      // store reference to server id on output_area\n",
       "      output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n",
       "    }\n",
       "  }\n",
       "\n",
       "  function register_renderer(events, OutputArea) {\n",
       "\n",
       "    function append_mime(data, metadata, element) {\n",
       "      // create a DOM node to render to\n",
       "      var toinsert = this.create_output_subarea(\n",
       "        metadata,\n",
       "        CLASS_NAME,\n",
       "        EXEC_MIME_TYPE\n",
       "      );\n",
       "      this.keyboard_manager.register_events(toinsert);\n",
       "      // Render to node\n",
       "      var props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n",
       "      render(props, toinsert[toinsert.length - 1]);\n",
       "      element.append(toinsert);\n",
       "      return toinsert\n",
       "    }\n",
       "\n",
       "    /* Handle when an output is cleared or removed */\n",
       "    events.on('clear_output.CodeCell', handleClearOutput);\n",
       "    events.on('delete.Cell', handleClearOutput);\n",
       "\n",
       "    /* Handle when a new output is added */\n",
       "    events.on('output_added.OutputArea', handleAddOutput);\n",
       "\n",
       "    /**\n",
       "     * Register the mime type and append_mime function with output_area\n",
       "     */\n",
       "    OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n",
       "      /* Is output safe? */\n",
       "      safe: true,\n",
       "      /* Index of renderer in `output_area.display_order` */\n",
       "      index: 0\n",
       "    });\n",
       "  }\n",
       "\n",
       "  // register the mime type if in Jupyter Notebook environment and previously unregistered\n",
       "  if (root.Jupyter !== undefined) {\n",
       "    var events = require('base/js/events');\n",
       "    var OutputArea = require('notebook/js/outputarea').OutputArea;\n",
       "\n",
       "    if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n",
       "      register_renderer(events, OutputArea);\n",
       "    }\n",
       "  }\n",
       "\n",
       "  \n",
       "  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n",
       "    root._bokeh_timeout = Date.now() + 5000;\n",
       "    root._bokeh_failed_load = false;\n",
       "  }\n",
       "\n",
       "  var NB_LOAD_WARNING = {'data': {'text/html':\n",
       "     \"<div style='background-color: #fdd'>\\n\"+\n",
       "     \"<p>\\n\"+\n",
       "     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n",
       "     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n",
       "     \"</p>\\n\"+\n",
       "     \"<ul>\\n\"+\n",
       "     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n",
       "     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n",
       "     \"</ul>\\n\"+\n",
       "     \"<code>\\n\"+\n",
       "     \"from bokeh.resources import INLINE\\n\"+\n",
       "     \"output_notebook(resources=INLINE)\\n\"+\n",
       "     \"</code>\\n\"+\n",
       "     \"</div>\"}};\n",
       "\n",
       "  function display_loaded() {\n",
       "    var el = document.getElementById(\"1001\");\n",
       "    if (el != null) {\n",
       "      el.textContent = \"BokehJS is loading...\";\n",
       "    }\n",
       "    if (root.Bokeh !== undefined) {\n",
       "      if (el != null) {\n",
       "        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n",
       "      }\n",
       "    } else if (Date.now() < root._bokeh_timeout) {\n",
       "      setTimeout(display_loaded, 100)\n",
       "    }\n",
       "  }\n",
       "\n",
       "\n",
       "  function run_callbacks() {\n",
       "    try {\n",
       "      root._bokeh_onload_callbacks.forEach(function(callback) { callback() });\n",
       "    }\n",
       "    finally {\n",
       "      delete root._bokeh_onload_callbacks\n",
       "    }\n",
       "    console.info(\"Bokeh: all callbacks have finished\");\n",
       "  }\n",
       "\n",
       "  function load_libs(js_urls, callback) {\n",
       "    root._bokeh_onload_callbacks.push(callback);\n",
       "    if (root._bokeh_is_loading > 0) {\n",
       "      console.log(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n",
       "      return null;\n",
       "    }\n",
       "    if (js_urls == null || js_urls.length === 0) {\n",
       "      run_callbacks();\n",
       "      return null;\n",
       "    }\n",
       "    console.log(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n",
       "    root._bokeh_is_loading = js_urls.length;\n",
       "    for (var i = 0; i < js_urls.length; i++) {\n",
       "      var url = js_urls[i];\n",
       "      var s = document.createElement('script');\n",
       "      s.src = url;\n",
       "      s.async = false;\n",
       "      s.onreadystatechange = s.onload = function() {\n",
       "        root._bokeh_is_loading--;\n",
       "        if (root._bokeh_is_loading === 0) {\n",
       "          console.log(\"Bokeh: all BokehJS libraries loaded\");\n",
       "          run_callbacks()\n",
       "        }\n",
       "      };\n",
       "      s.onerror = function() {\n",
       "        console.warn(\"failed to load library \" + url);\n",
       "      };\n",
       "      console.log(\"Bokeh: injecting script tag for BokehJS library: \", url);\n",
       "      document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
       "    }\n",
       "  };var element = document.getElementById(\"1001\");\n",
       "  if (element == null) {\n",
       "    console.log(\"Bokeh: ERROR: autoload.js configured with elementid '1001' but no matching script tag was found. \")\n",
       "    return false;\n",
       "  }\n",
       "\n",
       "  var js_urls = [\"https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-gl-1.0.0.min.js\"];\n",
       "\n",
       "  var inline_js = [\n",
       "    function(Bokeh) {\n",
       "      Bokeh.set_log_level(\"info\");\n",
       "    },\n",
       "    \n",
       "    function(Bokeh) {\n",
       "      \n",
       "    },\n",
       "    function(Bokeh) {\n",
       "      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.css\");\n",
       "      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.css\");\n",
       "      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.css\");\n",
       "      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.css\");\n",
       "      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.css\");\n",
       "      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.css\");\n",
       "    }\n",
       "  ];\n",
       "\n",
       "  function run_inline_js() {\n",
       "    \n",
       "    if ((root.Bokeh !== undefined) || (force === true)) {\n",
       "      for (var i = 0; i < inline_js.length; i++) {\n",
       "        inline_js[i].call(root, root.Bokeh);\n",
       "      }if (force === true) {\n",
       "        display_loaded();\n",
       "      }} else if (Date.now() < root._bokeh_timeout) {\n",
       "      setTimeout(run_inline_js, 100);\n",
       "    } else if (!root._bokeh_failed_load) {\n",
       "      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n",
       "      root._bokeh_failed_load = true;\n",
       "    } else if (force !== true) {\n",
       "      var cell = $(document.getElementById(\"1001\")).parents('.cell').data().cell;\n",
       "      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n",
       "    }\n",
       "\n",
       "  }\n",
       "\n",
       "  if (root._bokeh_is_loading === 0) {\n",
       "    console.log(\"Bokeh: BokehJS loaded, going straight to plotting\");\n",
       "    run_inline_js();\n",
       "  } else {\n",
       "    load_libs(js_urls, function() {\n",
       "      console.log(\"Bokeh: BokehJS plotting callback run at\", now());\n",
       "      run_inline_js();\n",
       "    });\n",
       "  }\n",
       "}(window));"
      ],
      "application/vnd.bokehjs_load.v0+json": "\n(function(root) {\n  function now() {\n    return new Date();\n  }\n\n  var force = true;\n\n  if (typeof (root._bokeh_onload_callbacks) === \"undefined\" || force === true) {\n    root._bokeh_onload_callbacks = [];\n    root._bokeh_is_loading = undefined;\n  }\n\n  \n\n  \n  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n    root._bokeh_timeout = Date.now() + 5000;\n    root._bokeh_failed_load = false;\n  }\n\n  var NB_LOAD_WARNING = {'data': {'text/html':\n     \"<div style='background-color: #fdd'>\\n\"+\n     \"<p>\\n\"+\n     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n     \"</p>\\n\"+\n     \"<ul>\\n\"+\n     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n     \"</ul>\\n\"+\n     \"<code>\\n\"+\n     \"from bokeh.resources import INLINE\\n\"+\n     \"output_notebook(resources=INLINE)\\n\"+\n     \"</code>\\n\"+\n     \"</div>\"}};\n\n  function display_loaded() {\n    var el = document.getElementById(\"1001\");\n    if (el != null) {\n      el.textContent = \"BokehJS is loading...\";\n    }\n    if (root.Bokeh !== undefined) {\n      if (el != null) {\n        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n      }\n    } else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(display_loaded, 100)\n    }\n  }\n\n\n  function run_callbacks() {\n    try {\n      root._bokeh_onload_callbacks.forEach(function(callback) { callback() });\n    }\n    finally {\n      delete root._bokeh_onload_callbacks\n    }\n    console.info(\"Bokeh: all callbacks have finished\");\n  }\n\n  function load_libs(js_urls, callback) {\n    root._bokeh_onload_callbacks.push(callback);\n    if (root._bokeh_is_loading > 0) {\n      
console.log(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n      return null;\n    }\n    if (js_urls == null || js_urls.length === 0) {\n      run_callbacks();\n      return null;\n    }\n    console.log(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n    root._bokeh_is_loading = js_urls.length;\n    for (var i = 0; i < js_urls.length; i++) {\n      var url = js_urls[i];\n      var s = document.createElement('script');\n      s.src = url;\n      s.async = false;\n      s.onreadystatechange = s.onload = function() {\n        root._bokeh_is_loading--;\n        if (root._bokeh_is_loading === 0) {\n          console.log(\"Bokeh: all BokehJS libraries loaded\");\n          run_callbacks()\n        }\n      };\n      s.onerror = function() {\n        console.warn(\"failed to load library \" + url);\n      };\n      console.log(\"Bokeh: injecting script tag for BokehJS library: \", url);\n      document.getElementsByTagName(\"head\")[0].appendChild(s);\n    }\n  };var element = document.getElementById(\"1001\");\n  if (element == null) {\n    console.log(\"Bokeh: ERROR: autoload.js configured with elementid '1001' but no matching script tag was found. 
\")\n    return false;\n  }\n\n  var js_urls = [\"https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.js\", \"https://cdn.pydata.org/bokeh/release/bokeh-gl-1.0.0.min.js\"];\n\n  var inline_js = [\n    function(Bokeh) {\n      Bokeh.set_log_level(\"info\");\n    },\n    \n    function(Bokeh) {\n      \n    },\n    function(Bokeh) {\n      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.css\");\n      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-1.0.0.min.css\");\n      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.css\");\n      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-widgets-1.0.0.min.css\");\n      console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.css\");\n      Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-tables-1.0.0.min.css\");\n    }\n  ];\n\n  function run_inline_js() {\n    \n    if ((root.Bokeh !== undefined) || (force === true)) {\n      for (var i = 0; i < inline_js.length; i++) {\n        inline_js[i].call(root, root.Bokeh);\n      }if (force === true) {\n        display_loaded();\n      }} else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(run_inline_js, 100);\n    } else if (!root._bokeh_failed_load) {\n      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n      root._bokeh_failed_load = true;\n    } else if (force !== true) {\n      var cell = $(document.getElementById(\"1001\")).parents('.cell').data().cell;\n      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n    }\n\n  }\n\n  if (root._bokeh_is_loading === 0) {\n    console.log(\"Bokeh: BokehJS loaded, going straight to plotting\");\n    run_inline_js();\n  } else {\n    load_libs(js_urls, function() {\n   
   console.log(\"Bokeh: BokehJS plotting callback run at\", now());\n      run_inline_js();\n    });\n  }\n}(window));"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision.datasets as dsets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from bokeh.io import show, output_notebook\n",
    "from bokeh.plotting import figure, gridplot\n",
    "from bokeh.models import LinearAxis, Range1d\n",
    "output_notebook()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPU\n",
    "指定使用的 GPU 编号。  \n",
    "`watch -n 1 nvidia-smi` 实时查看 GPU 的运行状态。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:28:17.221752Z",
     "start_time": "2019-11-06T14:28:17.106472Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pin the notebook to a single GPU; the bare `.cuda()` calls in later cells rely on this choice.\n",
    "gpu_name = \"cuda:3\"\n",
    "torch.cuda.set_device(gpu_name)\n",
    "# Show the index of the device that is now current as the cell output.\n",
    "torch.cuda.current_device()\n",
    "# Alternative: build `device = torch.device(\"cuda:5\")` and move tensors/modules with `xxx.to(device)`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-02T15:34:18.036850Z",
     "start_time": "2019-11-02T15:34:18.032937Z"
    }
   },
   "source": [
    "### Data\n",
    "通过`torchvision.datasets`下载`MNIST`数据。  \n",
    "训练集：`train=True`  \n",
    "测试集：`train=False`  \n",
    "常用的还有`torchvision.datasets.ImageFolder()`，按文件夹取图片。  \n",
    "\n",
    "`torchvision.transforms`可以对图片做处理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:28:17.375772Z",
     "start_time": "2019-11-06T14:28:17.223355Z"
    }
   },
   "outputs": [],
   "source": [
    "# Arguments shared by both MNIST splits: storage folder, auto-download, image -> tensor conversion.\n",
    "mnist_kwargs = dict(root='../dataset', download=True, transform=transforms.ToTensor())\n",
    "train_dataset = dsets.MNIST(train=True, **mnist_kwargs)\n",
    "test_dataset = dsets.MNIST(train=False, **mnist_kwargs)\n",
    "\n",
    "# Training data is served in shuffled mini-batches of `batch_size` images.\n",
    "batch_size = 100\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "# Test data is served one image per batch, also in shuffled order.\n",
    "test_loader = DataLoader(test_dataset, shuffle=True, batch_size=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model\n",
    "用 LSTM 构造 RNN。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里定义的 LSTM 满足 `batch_first = True`, 即输入和输出的 Tensor 的维度中，`batch_size` 要放在第一个。  \n",
    "输入的 `Tensor(batch_size, seq_len, dim)`  \n",
    "否则，当 `batch_first = False`（默认值）时，输入应为 `Tensor(seq_len, batch_size, dim)`，可通过 `X = X.transpose(0,1)` 交换前两维得到。  \n",
    "\n",
    "作为单元组成部分的 Hidden 和 Cell 向量也有对应的维度，为 `Tensor(layer_size, batch_size, dim_hidden)`.\n",
    "\n",
    "经过 LSTM 输出的结果，只要保留最后一个时刻的结果，即在 `seq_len` 的一栏填 `-1`. 使其成为 `Tensor(batch_size, dim)`，再通过线性网络输出。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:28:17.418508Z",
     "start_time": "2019-11-06T14:28:17.377557Z"
    }
   },
   "outputs": [],
   "source": [
    "class RNN(nn.Module):\n",
    "    def __init__(self, dim_in, dim_hid, dim_out, layers_size):\n",
    "        super().__init__()\n",
    "        self.dim_hid = dim_hid\n",
    "        self.layers_size = layers_size\n",
    "        self.lstm = nn.LSTM(dim_in, dim_hid, layers_size, batch_first=True)\n",
    "        self.fc = nn.Linear(dim_hid, dim_out)\n",
    "    def forward(self, x, batch_size):\n",
    "        h0 = torch.zeros(self.layers_size, batch_size, self.dim_hid).cuda()\n",
    "        c0 = torch.zeros(self.layers_size, batch_size, self.dim_hid).cuda()\n",
    "        x, _ = self.lstm(x, (h0, c0))\n",
    "        return self.fc(x[:, -1, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:28:23.498114Z",
     "start_time": "2019-11-06T14:28:17.419888Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RNN(\n",
       "  (lstm): LSTM(28, 128, num_layers=2, batch_first=True)\n",
       "  (fc): Linear(in_features=128, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Optimisation settings: Adam learning rate and number of passes over the training set.\n",
    "lrate, epochs = 0.01, 2\n",
    "# Sequence layout fed to the LSTM: 28 time steps with 28 features each.\n",
    "sequence_length = dim_in = 28\n",
    "# Network size: hidden width, number of stacked LSTM layers, number of output classes.\n",
    "dim_hid, layers_size, num_classes = 128, 2, 10\n",
    "\n",
    "model = RNN(dim_in=dim_in, dim_hid=dim_hid, dim_out=num_classes, layers_size=layers_size)\n",
    "model.cuda()  # moves the parameters to the current GPU in place\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optim = torch.optim.Adam(params=model.parameters(), lr=lrate)\n",
    "# Display the module structure as the cell output.\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意：每次反向传播的时候都需要将参数的梯度归零。  \n",
    "`optim.step()`则在每个`Variable`的`grad`都被计算出来后，更新每个`Variable`的数值"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在每次训练中都用`train_loader`中的一个`batch`作为训练数据。  \n",
    "`Tensor.cuda()` 每个 `batch` 在实际使用之前，都先移入 `GPU` 后进行计算。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:28:45.028120Z",
     "start_time": "2019-11-06T14:28:23.500606Z"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "  <div class=\"bk-root\" id=\"bde10142-f5bb-4f8d-999b-e2032e8006b0\"></div>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/javascript": [
       "(function(root) {\n",
       "  function embed_document(root) {\n",
       "    \n",
       "  var docs_json = {\"95126b35-e60c-405f-bbe2-c27dac4c0980\":{\"roots\":{\"references\":[{\"attributes\":{\"below\":[{\"id\":\"1011\",\"type\":\"LinearAxis\"}],\"left\":[{\"id\":\"1016\",\"type\":\"LinearAxis\"}],\"renderers\":[{\"id\":\"1011\",\"type\":\"LinearAxis\"},{\"id\":\"1015\",\"type\":\"Grid\"},{\"id\":\"1016\",\"type\":\"LinearAxis\"},{\"id\":\"1020\",\"type\":\"Grid\"},{\"id\":\"1029\",\"type\":\"BoxAnnotation\"},{\"id\":\"1039\",\"type\":\"GlyphRenderer\"}],\"title\":{\"id\":\"1042\",\"type\":\"Title\"},\"toolbar\":{\"id\":\"1027\",\"type\":\"Toolbar\"},\"x_range\":{\"id\":\"1003\",\"type\":\"DataRange1d\"},\"x_scale\":{\"id\":\"1007\",\"type\":\"LinearScale\"},\"y_range\":{\"id\":\"1005\",\"type\":\"DataRange1d\"},\"y_scale\":{\"id\":\"1009\",\"type\":\"LinearScale\"}},\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},{\"attributes\":{\"callback\":null,\"data\":{\"x\":[0,1,2,3,4,5,6,7,8,9,10,11],\"y\":[2.3057875633239746,0.6396551728248596,0.3735741078853607,0.2768346667289734,0.1986483782529831,0.059680212289094925,0.1027182936668396,0.14535784721374512,0.06601426005363464,0.05358972027897835,0.027681956067681313,0.020977307111024857]},\"selected\":{\"id\":\"1048\",\"type\":\"Selection\"},\"selection_policy\":{\"id\":\"1047\",\"type\":\"UnionRenderers\"}},\"id\":\"1036\",\"type\":\"ColumnDataSource\"},{\"attributes\":{\"active_drag\":\"auto\",\"active_inspect\":\"auto\",\"active_multi\":null,\"active_scroll\":\"auto\",\"active_tap\":\"auto\",\"tools\":[{\"id\":\"1021\",\"type\":\"PanTool\"},{\"id\":\"1022\",\"type\":\"WheelZoomTool\"},{\"id\":\"1023\",\"type\":\"BoxZoomTool\"},{\"id\":\"1024\",\"type\":\"SaveTool\"},{\"id\":\"1025\",\"type\":\"ResetTool\"},{\"id\":\"1026\",\"type\":\"HelpTool\"}]},\"id\":\"1027\",\"type\":\"Toolbar\"},{\"attributes\":{},\"id\":\"1007\",\"type\":\"LinearScale\"},{\"attributes\":{\"bottom_units\":\"screen\",\"fill_alpha\":{\"value\":0.5},\"fill_color\":{\"value\":\"lightgrey\"},\"left_units\":\"screen\",\
"level\":\"overlay\",\"line_alpha\":{\"value\":1.0},\"line_color\":{\"value\":\"black\"},\"line_dash\":[4,4],\"line_width\":{\"value\":2},\"plot\":null,\"render_mode\":\"css\",\"right_units\":\"screen\",\"top_units\":\"screen\"},\"id\":\"1029\",\"type\":\"BoxAnnotation\"},{\"attributes\":{},\"id\":\"1009\",\"type\":\"LinearScale\"},{\"attributes\":{\"data_source\":{\"id\":\"1036\",\"type\":\"ColumnDataSource\"},\"glyph\":{\"id\":\"1037\",\"type\":\"Line\"},\"hover_glyph\":null,\"muted_glyph\":null,\"nonselection_glyph\":{\"id\":\"1038\",\"type\":\"Line\"},\"selection_glyph\":null,\"view\":{\"id\":\"1040\",\"type\":\"CDSView\"}},\"id\":\"1039\",\"type\":\"GlyphRenderer\"},{\"attributes\":{\"formatter\":{\"id\":\"1044\",\"type\":\"BasicTickFormatter\"},\"plot\":{\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},\"ticker\":{\"id\":\"1012\",\"type\":\"BasicTicker\"}},\"id\":\"1011\",\"type\":\"LinearAxis\"},{\"attributes\":{\"source\":{\"id\":\"1036\",\"type\":\"ColumnDataSource\"}},\"id\":\"1040\",\"type\":\"CDSView\"},{\"attributes\":{},\"id\":\"1012\",\"type\":\"BasicTicker\"},{\"attributes\":{},\"id\":\"1047\",\"type\":\"UnionRenderers\"},{\"attributes\":{\"plot\":null,\"text\":\"\"},\"id\":\"1042\",\"type\":\"Title\"},{\"attributes\":{\"plot\":{\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},\"ticker\":{\"id\":\"1012\",\"type\":\"BasicTicker\"}},\"id\":\"1015\",\"type\":\"Grid\"},{\"attributes\":{},\"id\":\"1044\",\"type\":\"BasicTickFormatter\"},{\"attributes\":{\"formatter\":{\"id\":\"1046\",\"type\":\"BasicTickFormatter\"},\"plot\":{\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},\"ticker\":{\"id\":\"1017\",\"type\":\"BasicTicker\"}},\"id\":\"1016\",\"type\":\"LinearAxis\"},{\"attributes\":{},\"id\":\"1046\",\"type\":\"BasicTickFormatter\"},{\"attributes\":{},\"id\":\"1017\",\"type\":\"BasicTicker\"},{\"attributes\":{\"dimension\":1,\"plot\":{\"id\":\"1002\",\"subtype\":\"Figure\",\"type\":\"Plot\"},\"ticker\":{\"id\":\"1017\",\"type\
":\"BasicTicker\"}},\"id\":\"1020\",\"type\":\"Grid\"},{\"attributes\":{},\"id\":\"1048\",\"type\":\"Selection\"},{\"attributes\":{\"line_alpha\":0.1,\"line_color\":\"#1f77b4\",\"x\":{\"field\":\"x\"},\"y\":{\"field\":\"y\"}},\"id\":\"1038\",\"type\":\"Line\"},{\"attributes\":{\"line_color\":\"#1f77b4\",\"x\":{\"field\":\"x\"},\"y\":{\"field\":\"y\"}},\"id\":\"1037\",\"type\":\"Line\"},{\"attributes\":{},\"id\":\"1021\",\"type\":\"PanTool\"},{\"attributes\":{},\"id\":\"1022\",\"type\":\"WheelZoomTool\"},{\"attributes\":{\"overlay\":{\"id\":\"1029\",\"type\":\"BoxAnnotation\"}},\"id\":\"1023\",\"type\":\"BoxZoomTool\"},{\"attributes\":{\"callback\":null},\"id\":\"1003\",\"type\":\"DataRange1d\"},{\"attributes\":{},\"id\":\"1024\",\"type\":\"SaveTool\"},{\"attributes\":{},\"id\":\"1025\",\"type\":\"ResetTool\"},{\"attributes\":{\"callback\":null},\"id\":\"1005\",\"type\":\"DataRange1d\"},{\"attributes\":{},\"id\":\"1026\",\"type\":\"HelpTool\"}],\"root_ids\":[\"1002\"]},\"title\":\"Bokeh Application\",\"version\":\"1.0.0\"}};\n",
       "  var render_items = [{\"docid\":\"95126b35-e60c-405f-bbe2-c27dac4c0980\",\"roots\":{\"1002\":\"bde10142-f5bb-4f8d-999b-e2032e8006b0\"}}];\n",
       "  root.Bokeh.embed.embed_items_notebook(docs_json, render_items);\n",
       "\n",
       "  }\n",
       "  if (root.Bokeh !== undefined) {\n",
       "    embed_document(root);\n",
       "  } else {\n",
       "    var attempts = 0;\n",
       "    var timer = setInterval(function(root) {\n",
       "      if (root.Bokeh !== undefined) {\n",
       "        embed_document(root);\n",
       "        clearInterval(timer);\n",
       "      }\n",
       "      attempts++;\n",
       "      if (attempts > 100) {\n",
       "        console.log(\"Bokeh: ERROR: Unable to run BokehJS code because BokehJS library is missing\");\n",
       "        clearInterval(timer);\n",
       "      }\n",
       "    }, 10, root)\n",
       "  }\n",
       "})(window);"
      ],
      "application/vnd.bokehjs_exec.v0+json": ""
     },
     "metadata": {
      "application/vnd.bokehjs_exec.v0+json": {
       "id": "1002"
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "log_every = 100  # record the loss once every `log_every` mini-batches\n",
    "result = []\n",
    "for e in range(epochs):\n",
    "    for i, (inputs, targets) in enumerate(train_loader):\n",
    "        # Drop the size-1 axis at position 1 so the LSTM receives a 3-D (batch, seq_len, features) tensor.\n",
    "        inputs = inputs.squeeze(1).cuda()\n",
    "        targets = targets.cuda()\n",
    "        optim.zero_grad()  # gradients accumulate across backward() calls, so clear them first\n",
    "        # Pass the size of this batch rather than the configured `batch_size`: the last\n",
    "        # batch of an epoch is smaller whenever the dataset size is not a multiple of it.\n",
    "        outputs = model(inputs, inputs.size(0))\n",
    "        loss = criterion(outputs, targets)\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        if i % log_every == 0:\n",
    "            result.append(loss.item())\n",
    "# Plot the sampled losses; the figure is labelled so it can be read on its own.\n",
    "fig = figure(title='Training loss',\n",
    "             x_axis_label='logged step (one per %d batches)' % log_every,\n",
    "             y_axis_label='cross-entropy loss')\n",
    "fig.line(range(len(result)), result)\n",
    "show(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`re = torch.max(Tensor, dim)` 返回一个二元组 `(values, indices)`：`re[0]` 为沿 `dim` 取得的最大值 `Tensor`，`re[1]` 为这些最大值对应的 `index`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:28:59.488625Z",
     "start_time": "2019-11-06T14:28:45.029756Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the model on the 10000 test images: 97.21 %\n"
     ]
    }
   ],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "wrong_count = 0\n",
    "wrong_classify = []  # up to 5 misclassified samples, each stored as [image, prediction, label]\n",
    "# Inference only: disabling autograd avoids building a graph for every test image.\n",
    "with torch.no_grad():\n",
    "    for i, (inputs, targets) in enumerate(test_loader):\n",
    "        inputs = inputs.squeeze(1).cuda()\n",
    "        # Use the actual batch size instead of a hard-coded 1.\n",
    "        outputs = model(inputs, inputs.size(0))\n",
    "        # torch.max over dim 1 returns (values, indices); the index is the predicted class.\n",
    "        _, preds = torch.max(outputs, 1)\n",
    "        preds = preds.cpu()  # targets were never moved to the GPU, so compare on the CPU\n",
    "        total += len(outputs)\n",
    "        correct += (preds == targets).sum()\n",
    "        # test_loader yields one image per batch, so this comparison is a single boolean.\n",
    "        if wrong_count < 5 and preds != targets:\n",
    "            wrong_classify.append([inputs.data, preds, targets])\n",
    "            wrong_count += 1\n",
    "accuracy = 100 * correct.double() / total\n",
    "# Report the number of images actually evaluated rather than a hard-coded count.\n",
    "print('Accuracy of the model on the %d test images: %.2f %%' % (total, accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:29:00.226023Z",
     "start_time": "2019-11-06T14:28:59.491113Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA2oAAAC/CAYAAACPMC8KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xu8FXW9//H3B9iAXAVB5CaYgoppWOTlaGaZpqLHtDqm5S0Ts9QsrbxUdrr8tDQr02NhEmpeuqjJEUnMk1mKF1QyEbkoolwFUUFA2Gw+vz9msCXf2ezZ6zqz1uv5eOwH7M+aNfOZvd971nzXWvNd5u4CAAAAAGRHh1o3AAAAAAB4NwZqAAAAAJAxDNQAAAAAIGMYqAEAAABAxjBQAwAAAICMYaAGAAAAABnDQK1BmdlEM/tBrfsA2sPMHjSzL9S6DyAtM/uumf221n0AaXF+gLyp53ODhhmomdlLZvaxKm9zvJnNNrNNZnZqNbddCjPbz8zuN7OVZrbczP5gZgNr3VcjqlFu3czWmNlb8devq7n9YlnkEjN72cxWmdntZtar1n01mhpl9qNm9lT8e3/RzMZVc/ulMLOOZvYDM1tsZqvN7Gkz27bWfTWSamfWzEaa2d3x4+tKM7vPzHat1vZLYWafLXhseMvM1saPGR+odW+NpBbH2YJtnxL/znMxMMr7uUHDDNTaYmadKrDaf0r6kqSnKrDud1Sg9z6SxksaLmmYpNWSflPmbaAMKpRbSXqfu/eIvypyMK5A7ydLOknSAZIGSdpG0i/KvA2UqNy/dzNrknSXpF9J6i3peElXmdn7yrmdeFuV+Hv7b0n/IWl/Sb0UZfjtCmwHRarA731bSZMk7SppgKTHJd1d5m1IKn/v7n5LwWNDD0XnOC+qwuc5aJ9KnRuYWR9JF0maWYn1x9vg3KBAQwzUzOxmSTtK+t/4GaBvmNnw+BmB083sZUn/Z2YHm9nCLe77zrMWZtbBzC40sxfM7DUz+72Z9W1tu+5+rbs/oCIedOPtXmRmz5nZ62b2GzPrGt92sJktNLNvmtlSxYMoMzvKzGaY2Rtm9oiZ7VWwvr3jZ5xXm9nvJHXdSt9T3P0P7r7K3ddKukZRwFFFtcptiT27mZ0bv6qxwsyuMLMO8W2nmtnDZvZTM1sp6btx/fNmNivO+X1mNqxgfYea2fNm9qaZXSPJtrL5oyXd4O6vuPtbkn4k6Xgz61aJfUWoRpntq2iAc7NHnpA0S9KolD0/aGaXmdnjcc7u3rytpN7j+n7xMfYNM/unmR1csL6dzOxv8bH2fkn9trLtPpLOk3SGuy+I+3/W3RmoVUktMuvuj7v7De6+0t2bJf1U0q5mtl3Knmt2fpDgFEk3ubu34z4oQY3PDS6TdLWkFe3smXODIjXEQM3dT5L0sqSj42eBflxw84cl7S7p4ylWda6kT8T3GSTpdUnXbr7RzJ4xsxPL1rj02bivnSWNlPStgtt2UHSCMkzSODN7v6QJks6UtJ2iZ5cnmVkXM+ss6U+Sbo7v8wdJnyzcUHzwPrCVPg5SBZ89QbIa5/YhM1tqZnea2fB2tn6spDGS3i/pGEmfL7htX0XPvm4v6Ydm9glJF0s6TlJ/SX+XdFvcVz9JdyjKfT9JL6jgCQMz2zHO7Y6bS3r3wdokdZE0op39o0i1yKy7L1OUmdMsehvh/oqOi/9oR+snK8rpIEkbFZ2IFHqndzMbLGmypB8oOp5eIOkOM+sfL3urpCcVZfb7ik5k37HF39ue8fY+Ff+9zTGzL7ejb5QoI+cHB0la6u6vtaP1mp8fxCfOB0m6qR19o0S1yqyZ7aPosf2XRbbOuUEx3L0hviS9JOljBd8Pl+SS3lNQO1jSwtbup+hZ2kMKbhsoqVlSpza2/Q9JpxbR7xcLvj9S0gsFfW6Q1LXg9uskfX+LdcxW9Ad4kKTFkqzgtkck/SBFH3tJWinp
Q7X+HTbiVy1yG+els6K351wj6dm2Ml5wX5d0eMH3X5L0QPz/UyW9vMXyUySdXvB9B0lrFZ1gnCzp0YLbTNJCSV9oZdtfkDQn/hn1VvTWIpe0f61/j430VaPMHi1pmaJBz0ZFr1Cl7fdBSZcXfD8qPr52bKX3byp69a5wHfcpGpDtGG+/e8Ftt0r6bSvbPjFe/w2K3o6zl6Tlkg6t9e+xkb5qkdmC5YZIWiTphHb2m4Xzg29LerDWv79G/Kp2ZuPj4XTFj6fxcTPxsbiVfjk3KPKrIV5Ra8Mr7Vh2mKS74pH6G4pC3qLoPeaVUNjbAkXPeGy23N/99phhks7f3Fvc39D4PoMkLfI4sQXr2yoz20XRH8tX3P3vxe4EKqJiuXX3h9x9g7u/IekrknZS9AxdMb1tmdst+x4m6ecFva1UdNAdHN/vneXj/G5tvycoesbtQUWvAP81ri9s7Q6oqopk1sx2k/Q7RQ/enSXtIekbZja2yN4WSGrSu9+yWHj7MEmf3uJYe6Cik5xBkl539zVbrK816+J/v+fu69z9GUm3KzrxRu1V9PwgfhV2qqT/cffbSuit6ucHsZMl3djOvlFZlcrslyQ94+7TytQb5wYpNdJArbX3TxfW10h65z2rZtZR0Uuum70i6Qh337bgq6u7Lyp/u5KiA+lmOyp61muzLffnFUk/3KK3bvHBf4mkwWZW+NLvjtqK+C0Nf1H0LNzNxe8CSpSF3Lq2/v7vLbU3t2du0ds27v6Ioty+s644v0PVCnff5O6Xuvtwdx+i6IC8KP5C9VQ7s++VNNvd74szMFvRWxOPaEfPW2a2We++BqOw91cUvaJW2Ft3d79cUWb7mFn3LdbXmmcS1o/qq/px1qLrE6dKmuTuPyyi55qdH8T9b56Y4Y9F9I7SVTuzh0g6Nn6L9lJFEyD9JL4+LC3ODYrQSAO1ZZLe08YycyR1NbOxFs0k9i1F72Pd7JeK3js7TIqeDTOzY1pbmZl1tugCX5PUZGZdCy6ePNjM2npw/rKZDYkv7rxY0bPGrble0hfNbF+LdI/3o6ekaYrejnOumXUys+Mk7bOVvgcrumj+Wncv9r3IKI+q5tbM9jCz0fG1Pj0k/UTRwWxWfPupZvZSG/183cz6mNlQRa/IbS23v5R0kZntEa+/t5l9Or5tsqQ9zOw4i2aBOlfRtReJzKyvme0c53+UpKsUvVKxqY1+UV7VPtY+LWmERVP0m5ntLOkoRbPuFk4IMnwr/XzOzEZZdHH59yT90d1bWln2t5KONrOPx38nXePj+RB3X6Do7UH/HR//D1T0tsxE7v6ComsvLrHoeqHdFc1aec9WekX5Vfs420vR22UfdvcLE27P7PlBgVMk3eHuq1Msi/Kr9nH2VEXvrBkdf01XNGPtJfF9OTeokEYaqF0m6VsWvYx6QdIC7v6mopd3f63o5HSN3v3S6M8Vvbd1qpmtlvSoogsgJUlmNtPMPluw/FRFb235D0XT3a9T9H5wKRr9t/US8q3xOl6Mv1r9AEp3ny7pDEXXFL0uaZ6iPyy5+wZFF2SeGt92vKQ7C+9v0cxBH4q//YKiA8ClVvB5KW30isqodm4HKDp4rlKUueGSjvJoZjIpyu3DbfR8t6LJFGYoOqDe0NqC7n6XohmYbjezVYquhzsivm2FpE9LulzSa4ou/H1n2xZdMPyW/fuC4X6S7o33f4qkCe4+vo1eUX5VzWw82Pm8oglAVkn6m6ILzTfnbqiit9ls7dnTmyVNlLRU0Yx357a2oLu/ouhC+IsVXU/2iqSv69+PpyfGva6UdKm2mGgh4XHiBEVv83lN0d/Ltz2aLRjVU+3j7LGSPqhoApzCzyTbfCzL8vmBLHoC+r/E2x5rqdrH2TfcfenmL0XXQa6KtyFxblAx9u63JaNaLPoQ4T+4+32t3P6Sogsj/1LVxoCtMLOpiq5ZnNXK7S5phLvPq25nQDIz+5aia3Z+1crtDyqa7CMXH+yO+sf5AfKGc4PK
qdSH5aINXqEPEQYqyd0Pq3UPQHu4e6uvNABZxPkB8oZzg8pppLc+AgAAAEAu8NZHAAAAAMgYXlEDAAAAgIwp6Ro1Mztc0awxHSX9Ov4cmVZ1ti7eVd23tgjQqre1Rht8fXs+zytAZlFN5cis1L7cklmUarVeX+Hu/dtesnVkFtVU7cxK5BalSXt+UPRAzaIPzrtW0qGKpvt8wswmuftzrd2nq7prXzuk2E2iwT1W4ozVZBbVVmpmpfbnlsyiVH/xPy4o5f5kFtVW7cxK5BalSXt+UMpbH/eRNM/dX4w/h+N2RZ8tA2QVmUUekVvkDZlF3pBZZFIpA7XBij7oc7OFce1dzGycmU03s+nNWl/C5oCSkVnkUZu5JbPIGDKLvOH8AJlUykAt6X2VwRSS7j7e3ce4+5gmdSlhc0DJyCzyqM3ckllkDJlF3nB+gEwqZaC2UNLQgu+HSFpcWjtARZFZ5BG5Rd6QWeQNmUUmlTJQe0LSCDPbycw6S/qMpEnlaQuoCDKLPCK3yBsyi7whs8ikomd9dPeNZna2pPsUTWU6wd1nlq0zoMzILPKI3CJvyCzyhswiq0r6HDV3v1fSvWXqBag4Mos8IrfIGzKLvCGzyKJS3voIAAAAAKgABmoAAAAAkDEM1AAAAAAgYxioAQAAAEDGMFADAAAAgIxhoAYAAAAAGcNADQAAAAAyhoEaAAAAAGQMAzUAAAAAyBgGagAAAACQMQzUAAAAACBjGKgBAAAAQMZ0qnUD+LeOo0Ym1hcd2i+odTt8WVCb9r47gtrIG88KajtdNK2I7gAg/+Zc/8HE+vyx1we1cxeHy84e01z2ngAASMIragAAAACQMQzUAAAAACBjGKgBAAAAQMYwUAMAAACAjClpMhEze0nSakktkja6+5hyNJVn68eGF5+/cmjHoLbn3vOD2mXDJiauc5emLqm23eypFmt45BZ5Q2aL8+a9uwS1+aPDSUNac/WgJ4LaxzW6pJ4aRSNntkP37kFt5XF7hbUj1gW1uQdPDGr7PP3pxO00T+kf1AZMezOo2fMvBbVNa9YkrrORNXJmkV3lmPXxI+6+ogzrAaqJ3CJvyCzyhswib8gsMoW3PgIAAABAxpQ6UHNJU83sSTMbl7SAmY0zs+lmNr1Z60vcHFAWW80tmUUGkVnkDZlF3nBOi8wp9a2PB7j7YjPbXtL9Zva8uz9UuIC7j5c0XpJ6WV+uokIWbDW3ZBYZRGaRN2QWecM5LTKnpIGauy+O/33VzO6StI+kh7Z+rxoxC0oduoSTdCw7de/Eu7cc/kZQu2TUvUHt4G0eDmp9OnRN06GaPfnX8dam8FmbHh3STTCCUK5yW0c6jnhPUGveoXdQm3daOPlOh84tiev81B5PB7UfDZgR1EY98rmgNvRTzyauM4vIbNvmXB9O5JQ0ccjktcnH4+/POSqoLVvUJ6iNVDjBCEKNnNk5P9gzqD3/X9ekum/SpGAPj749eeGU89rs8ffTgtp7PvdcUPONG9OtsE41cmYrYe7V+ybWt50Vvpnv5m9eFdQu+M8wt5ueeb70xnKm6Lc+mll3M+u5+f+SDpOUnzMfNCRyi7whs8gbMou8IbPIqlJeURsg6S6LXqnqJOlWd/9zWboCKofcIm/ILPKGzCJvyCwyqeiBmru/KOl9ZewFqDhyi7whs8gbMou8IbPIKqbnBwAAAICMKccHXtdUh27dgtrKT4ZPirx+5NqgNvOgCQlrDCcDaZ/wQvXH14cTmXzuL2cGtSF/Th4393pycVD70yN/KqI3NJLlX9w/qLV0CbM4+KZZ4XKvvx6usEM4yceGw96fuO2VuzYFtSvOCSd2+Mg2byfevxQPrOsc1Jo35P5QhwIrzgyzPX/sdanue8U5JyXWe08JJwkJp7oB3m392HASm6nHXZmw5DaVb6YVMz/0m6A2+uvnBLUhlz1SjXZQh1Yfv19Qe/LYcIIQSVr6n2Ft987huTwivKIGAAAAABnDQA0AAAAAMoaB
GgAAAABkDAM1AAAAAMiY3F9h/8J3wolDZp50TVW2Pad5Q1Ab9/xng1qP7/YIaiMfDS9cb9Wwoe3qq1CvF4u+KzLo9VPCSRTWDAonCJGkGWf/ItU6V5y/Lqi1pOxn2w7Jk+90sXAykSRT13UPatPeGhHU/jjpwMT7b/9U2Gn3qeFnlL5n7YxU/SB7Ou66S1D73tfDyRGS7Pbrs4LasCnTSu4J2Gzxh8LTqB07pZs4ZPybw4Pa2k3hZEgn9Ppn4v0HdCx+gpKdDw9PDtZfVvTq0EA67rFrUPvGD34b1C5YdFji/ee+2T+onbJjeFz25+YV0V394RU1AAAAAMgYBmoAAAAAkDEM1AAAAAAgYxioAQAAAEDGMFADAAAAgIzJ/ayPzX03BrX13hzU1no4O9x+k74W1IbfnTzfXdf5K4OabQi302NB+adZnHvWkFTLXfHaqKDW/+ang9qmkjtCNaw+fr+g9uAPfx7UmqxjSdvpV8LMYe3x4a99Oaj1+dv8oLZx6bKgNkzpZ+oj3/Vl5U/D2thubwe1yWu7BrX33LI8qKWd0RRIZdjaVIslzXA7+cO7BbWW5WFmbz37G4nr3P7Yl4PaPbvdnaqfS3ecFNS+vcvxicu2zAuP02hcCz++XVA7bJvwHPmrs8LZISXpa/v8JajNXx/OBOkbw/P7RsQragAAAACQMQzUAAAAACBjGKgBAAAAQMa0OVAzswlm9qqZPVtQ62tm95vZ3PjfPpVtE2gfcou8IbPIGzKLvCGzyJs0k4lMlHSNpJsKahdKesDdLzezC+Pvv1n+9to2ctwTQe2wz5wX1LquDC9KHDH1sdTbqeUF6Icd8lSq5W6ctW9QG/72M+VuJy8mKsO5TWPNDuHzKEkTh3zw8nMS7z/4rgVl76kUPReFf28b3WvQSWZNVM4zWwnfHnlPquWuOOekoNZldvj4gLKaqAbPbN/J4WRMe849O6jtfF04IUfL8qWptrH9NY8k1jvdHU40tn5aOMlZF2sKant1Dh9Llh+0Q+J2+tbXZCIT1eCZbY8O7w0nvPnNOT8LanvcE2Z+5BcfT1znTfeE56orFm4b3l/J9280bb6i5u4PSdpyOpdjJN0Y//9GSZ8oc19AScgt8obMIm/ILPKGzCJvir1GbYC7L5Gk+N/ty9cSUDHkFnlDZpE3ZBZ5Q2aRWRX/HDUzGydpnCR1VbdKbw4oGZlF3pBZ5A2ZRR6RW1Rbsa+oLTOzgZIU//tqawu6+3h3H+PuY5rUpcjNAWWRKrdkFhlCZpE3ZBZ5wzktMqvYV9QmSTpF0uXxv3eXraMy6Hn7o7VuoSgd90j+FPdvDLghoRpewOzzu5e5o7qT6dwW68dfuT6x/rPbDgpqLSteq3Q7KK+6zGxrVpy5f1Ab221GUDt38QeDWpcp5Z84ZP0R4XYWHBcuN2Dw60Gt71eT19kye16pbWVdQ2W292/D843eCcuF05mVbuMrC4PaXg98KajN/ljyYwTe0VCZbVWHcIKZBd8NhwkvNfcLaqN+tCyoVSLzjSjN9Py3SZomaVczW2hmpysK86FmNlfSofH3QGaQW+QNmUXekFnkDZlF3rT5ipq7n9DKTYeUuRegbMgt8obMIm/ILPKGzCJvir1GDQAAAABQIQzUAAAAACBjKj49P9JbfMh2ifWBHcOJQ5IMm/J2OdtBjXXc4KmW+8g2yb/3n3ftWs52gMx4+PoxQa2fppW0zo677hLUvv6Lm4Pa2G7pjrMfOPisxHq/+p9MBDW02/kLgtpNfx8c1E7utSiorf/PN5JXOqHktpADvv+eQe25/5gY1Pb8WThhzaD5j6TeziW73hvUvrrwxNT3bzS8ogYAAAAAGcNADQAAAAAyhoEaAAAAAGQMAzUAAAAAyBgmEwEyqv914eQIyy5aF9QGtDLZzIbh/YNah4XhBeRAVjQdszzVcmvCuRHUr8Rt73JLOAlD2olDJq8NJ+7p+fLGEjsC2m/dB3YKasf3nJywZFNQGdL7
zcR1tpTaFLKnQ8eg9MbFa4Lat14NJxgZfNXjQS3d1GeRfbssDWpNr4f9IMIragAAAACQMQzUAAAAACBjGKgBAAAAQMYwUAMAAACAjGEyESBHPvboWUHtXwdMTFz2xePCCQ52eSThgt1NJVwqnnBBsiR16BxeqL5x392DWsdHZgY1b95QfD/Itea7wwlwNDosPf+F64Labgr/NronzJ1zwBnTE7d99aAn2uxPSp445NqxRwW1LrPTrQ8op01N4fPvXSw8HidZfvOwxHpfLS6pJ2RPh/eOCGrTRt8a1D50/peCWs+Nj5a9n4GPpDsPsabOQa3ezxl4RQ0AAAAAMoaBGgAAAABkDAM1AAAAAMgYBmoAAAAAkDFtDtTMbIKZvWpmzxbUvmtmi8xsRvx1ZGXbBNIjs8gjcou8IbPIGzKLvEkz6+NESddIummL+k/d/cqydwSUbqLqNLO97+kRFg9IXvb5468NasePOTyozVwyMNW2t7urW1DrcvqSxGWnjrozofpwULlt9YCgdvtxHw1qLc/NabvB/JuoOs1tWv1+NS2o7TY4nM0xadbHpFolfH9OOMNj36psOZMmqsEzm1cLN64LatvOCWt1aKLIrBYcHR61/rwufIzvPemZoLapAv20dLGg9sJP9gtq2ydMptvz9vLPQpklbb6i5u4PSVpZhV6AsiCzyCNyi7whs8gbMou8KeUatbPN7Jn4ZeQ+rS1kZuPMbLqZTW/W+hI2B5SMzCKP2swtmUXGkFnkDecHyKRiB2rXSdpZ0UeRLpH0k9YWdPfx7j7G3cc0qUuRmwNKRmaRR6lyS2aRIWQWecP5ATKrqIGauy9z9xZ33yTpekn7lLctoLzILPKI3CJvyCzyhswiy9JMJhIws4HuvnkWgWMlPbu15VEdq4eEz+70rkEfWVQvme17R3hh7+hjTk5cdsZ+W14rLf1u5z+HC+6ccuMHplyuHU7ouSyobTdpUlC74tyTEu/f5d6EK4vrSL3kthTDvhNOMHLwtDOC2uodw4ezA86YHtSuHpScmZ0mh+vssiRcZ9KkJZMndw1qV5zTSmankFmUSYeOQWndWa+nuuuCjb2CWsf1LYnLevu6yp1GzOzQB94KamO//HZQmz7tjaB266QPB7Vt54bbePWg5sRt9+7weFC78sf/E9QunX9MULMLXw1q9Z7PNgdqZnabpIMl9TOzhZIulXSwmY1W9PN5SdKZFewRaBcyizwit8gbMou8IbPImzYHau5+QkL5hgr0ApQFmUUekVvkDZlF3pBZ5E0psz4CAAAAACqAgRoAAAAAZExRk4kgm/Y4J7z+ddHvOwc1b95QjXZQAZvWrAlqQz79fOKyx2x3eFB76YsjglpL5/BS3N7vXxHUThweToJw9ZMfTdx217nh5Ap9nwsvVN//kvCi4v83IJwAYvlVdydu5/cLPhLUWmbOTlwW9SNpQo5uu+4S1K6+NFxu8towm5K0+5XhJAwts+cFtZ0GhpOOzB97fVA7+7jEzWjklOQ68qljQu6WH9A/qG1KONtad9jqoNb8Qs/wvgnHaEnabuRrQe3h0bcnLrulcb8PL8Pa6Ylw4h7Uqcf+FZQ+ctoXglr3ixYFtf89+cpUm+hmybnt1qFHUDvllrOC2vBLw/MD35Q84U094xU1AAAAAMgYBmoAAAAAkDEM1AAAAAAgYxioAQAAAEDGMJlIhuxwXTiJgiTN+9r6oLZLU5eg9suhfwtqx2yTMNkCk4nUl1Yurm1ZvjyoDf1+WEtrirYNaiP0VNHrk6Tnpg8Nagsf+ntQ+2zPJYn3//4ZfYLaLueV1BJyatYFYRaSnH/raYn1YbPTTaQw8oxwghItDktJE4xI0pG7fiqoJU1agnx4+bLwsfjpfa8pfoX7l9BMK36zKjzO7jJhWVBrvGkaGpiHE310vi88B22+L7zrOTog1SbWHrdvYv3v1/wqqO3weEL6GnDikCS8ogYAAAAAGcNADQAAAAAyhoEaAAAAAGQM
AzUAAAAAyBgmE8kQb2WSj8/MOD2oTf/gbyvdDlBxGxe8EtTedp4/Qvv1nZ7wcDY2LG07pvgJdVpz7uIPBrWrByVMOiLpxc/2D2rDvsNkInl15V5/rHULbfrd2UcEtU5zn6xBJ2gki8K57CRJ85vfCmrd/m9mUNtU7oZyijMiAAAAAMgYBmoAAAAAkDEM1AAAAAAgYxioAQAAAEDGtDlQM7OhZvZXM5tlZjPN7Ctxva+Z3W9mc+N/+1S+XaBtZBZ5Q2aRR+QWeUNmkTdpZn3cKOl8d3/KzHpKetLM7pd0qqQH3P1yM7tQ0oWSvlm5VhvX6gW9w2I4yViit/cbGdSapk4vsaPMI7PIGzKbEwu+t39Qu2/QdUEtaSZISRr2nWll76mGyG0OdNjI/HkFyGyV+DYtifVmWVDbtGZNpdvJrTZfUXP3Je7+VPz/1ZJmSRos6RhJN8aL3SjpE5VqEmgPMou8IbPII3KLvCGzyJt2XaNmZsMl7S3pMUkD3H2JFAVf0vat3GecmU03s+nNWl9at0A7kVnkDZlFHrU3t2QWtcaxFnmQeqBmZj0k3SHpPHdflfZ+7j7e3ce4+5gmdSmmR6AoZBZ5Q2aRR8XklsyiljjWIi9SDdTMrElRoG9x9zvj8jIzGxjfPlDSq5VpEWg/Mou8IbPII3KLvCGzyJM2JxMxM5N0g6RZ7n5VwU2TJJ0i6fL437sr0mG5WHjx4pKvhReFr96tufyb7hJeUPmZPcMJPb7Z/7HE+zfp8YRqx1TbvmfC/wS1SWsGJC570UOfSrXOJP2mhVHqO6E2F83XTWYbQIfu3YNaR/PU99/u6fDvOo/IbOl6vrwx1XLfHnlPYv07Z54W1A44IzxOJ00ckmTqvWMS68NUP5OJkNvq+M2qoYn1n838aFD75/43JiyJzchs9TQtb6p1C3UhzayPB0g6SdK/zGxGXLtYUZh/b2anS3pZ0qcr0yLQbmQWeUNmkUfkFnlDZpErbQ7U3P0fUsJcmpFDytsOUDoyi7whs8gjcou8IbPIm3bN+ggAAAAAqDwGagAAAACQMWmuUasLHd63e1B78mu/qEEnW9O57GtssnDSkU/2WJG47CeP/GXR2xnZaVxQ6zuh6NWhQSw95X1BbadODwVB3DU/AAAH7klEQVS1J1v5uJr+f1kQ1NJNKYF602XKE0FtvxnhBEmPjv5j4v3HXppukpAkSdsZ9p36mTQE0srTwsnHJOmQbZ4s63Y+8MTngtqQC95OXHb4qy8HtT3PPyeo7fzSK0GN4yQqbfDfklPWdGI4YVingTsEtY1Llpa9pzziFTUAAAAAyBgGagAAAACQMQzUAAAAACBjGKgBAAAAQMY0zGQiq0b0LPs6kyY4WLWpa1A7+4kTg1rz26X96L+13+Sgtjph21NOPyiorR20TeI6Fx6d7vLiQVPCT5sf9diioMbFymiLp/wz+Ors/0qs91r0Qhm7Qb3p+9WwttMFZyQue/TeM4La468OC2pdf9EnqPVOmMgE9eWtYa199Fbxjp17VFAb8vXwxKJl3vzU6xx26SNBjcdi1ELnPycfF9d6OMndW2N2DGpd/5fJRCReUQMAAACAzGGgBgAAAAAZw0ANAAAAADKGgRoAAAAAZEzDTCbS4w+PBbWj/vCBqmx7J/2z7Ov8vcJPcU/2TFDp1sqSI+8suh0uVkabOvbqFdSOPf3BdHe+qX8rNzCZCFrXMnteUBuZPJeIZifUeiu8PxpT/xnJj3ILN64LakM6hRN27XrXl4LayPOeDGq+kUdT1LdTnz0lqHXYNpxgJJwerzHxihoAAAAAZAwDNQAAAADIGAZqAAAAAJAxDNQAAAAAIGPanEzEzIZKuknSDpI2SRrv7j83s+9KOkPS8njRi9393ko1CqRFZrOpZdTwoHZxv78Gtb+uCy8h7jN1TvI6S+4qG8gs8qbRMrvNnx5PrH/xTwemuv8IhROaeUkdoRiNltss6nRL36C29JDmoLbtzdXoJvvS
zPq4UdL57v6UmfWU9KSZ3R/f9lN3v7Jy7QFFIbPIGzKLvCGzyCNyi1xpc6Dm7kskLYn/v9rMZkkaXOnGgGKRWeQNmUXekFnkEblF3rTrGjUzGy5pb+md1/DPNrNnzGyCmfVp5T7jzGy6mU1v1vqSmgXai8wib8gs8obMIo/ILfIg9UDNzHpIukPSee6+StJ1knaWNFrRsxM/Sbqfu4939zHuPqZJXcrQMpAOmUXekFnkDZlFHpFb5EWqgZqZNSkK9C3ufqckufsyd29x902Srpe0T+XaBNqHzCJvyCzyhswij8gt8iTNrI8m6QZJs9z9qoL6wPi9vpJ0rKRnK9Mi0D5kNqMefSYoHTX4AynvvLK8vWQMmUXekFnkEbmtvV63PZpQq0EjOZFm1scDJJ0k6V9mNiOuXSzpBDMbrWiG2ZcknVmRDoH2I7PIGzKLvCGzyCNyi1xJM+vjPyRZwk18vgQyicwib8gs8obMIo/ILfKmXbM+AgAAAAAqj4EaAAAAAGQMAzUAAAAAyBgGagAAAACQMQzUAAAAACBjGKgBAAAAQMYwUAMAAACAjDF3r97GzJZLWhB/20/SiqptvLLqaV+k7O7PMHfvX80NktncyOr+kNnyqad9kbK9P1XNbR1nVqqv/cnyvtTyWJvln0sx6ml/srwvqTJb1YHauzZsNt3dx9Rk42VWT/si1d/+lEs9/VzqaV+k+tufcqmnn0s97YtUf/tTLvX2c6mn/amnfSmnevu51NP+1MO+8NZHAAAAAMgYBmoAAAAAkDG1HKiNr+G2y62e9kWqv/0pl3r6udTTvkj1tz/lUk8/l3raF6n+9qdc6u3nUk/7U0/7Uk719nOpp/3J/b7U7Bo1AAAAAEAy3voIAAAAABnDQA0AAAAAMqbqAzUzO9zMZpvZPDO7sNrbL5WZTTCzV83s2YJaXzO738zmxv/2qWWPaZnZUDP7q5nNMrOZZvaVuJ7L/akUMpsdZDYdMpsdZDa9POe2njIrkdu08pxZqb5yW6+ZrepAzcw6SrpW0hGSRkk6wcxGVbOHMpgo6fAtahdKesDdR0h6IP4+DzZKOt/dd5e0n6Qvx7+PvO5P2ZHZzCGzbSCzmUNmU6iD3E5U/WRWIrdtqoPMSvWV27rMbLVfUdtH0jx3f9HdN0i6XdIxVe6hJO7+kKSVW5SPkXRj/P8bJX2iqk0Vyd2XuPtT8f9XS5olabByuj8VQmYzhMymQmYzhMymluvc1lNmJXKbUq4zK9VXbus1s9UeqA2W9ErB9wvjWt4NcPclUhQUSdvXuJ92M7PhkvaW9JjqYH/KiMxmFJltFZnNKDK7VfWY27r4HZPbVtVjZqU6+B3XU2arPVCzhBqfD1BjZtZD0h2SznP3VbXuJ2PIbAaR2a0isxlEZttEbjOI3G4Vmc2geststQdqCyUNLfh+iKTFVe6hEpaZ2UBJiv99tcb9pGZmTYoCfYu73xmXc7s/FUBmM4bMtonMZgyZTaUec5vr3zG5bVM9ZlbK8e+4HjNb7YHaE5JGmNlOZtZZ0mckTapyD5UwSdIp8f9PkXR3DXtJzcxM0g2SZrn7VQU35XJ/KoTMZgiZTYXMZgiZTa0ec5vb3zG5TaUeMyvl9Hdct5l196p+STpS0hxJL0i6pNrbL0P/t0laIqlZ0bMpp0vaTtFMMnPjf/vWus+U+3Kgopfpn5E0I/46Mq/7U8GfE5nNyBeZTf1zIrMZ+SKz7fpZ5Ta39ZTZeH/IbbqfU24zG/dfN7mt18xavHMAAAAAgIyo+gdeAwAAAAC2joEaAAAAAGQMAzUAAAAAyBgGagAAAACQMQzUAAAAACBjGKgBAAAAQMYwUAMAAACAjPn/dI964H92hF4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 1080x5400 with 5 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(15, wrong_count*15))\n",
    "for i, (img, preds, truth) in enumerate(wrong_classify):\n",
    "    img = img.reshape(28, 28).cpu().numpy()\n",
    "    plt.subplot(1, wrong_count, i+1)\n",
    "    plt.imshow(img)\n",
    "    plt.title('true:%i, pred:%i' % (truth, preds))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-03T07:12:44.915877Z",
     "start_time": "2019-11-03T07:12:44.913744Z"
    }
   },
   "source": [
    "### Save Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-06T14:29:00.277464Z",
     "start_time": "2019-11-06T14:29:00.228472Z"
    }
   },
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), 'rnn_cuda.pkl')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "oldHeight": 511.99678000000006,
   "position": {
    "height": "533.984px",
    "left": "411.875px",
    "right": "20px",
    "top": "161.984px",
    "width": "760.516px"
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "varInspector_section_display": "block",
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
